{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "H9CL5JZzWuuj"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Probability Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "DcBheSMRXSq2"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VYXS33IGWx4Q"
      },
      "source": [
        "# Learnable Distributions Zoo\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/probability/examples/Learnable_Distributions_Zoo\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Learnable_Distributions_Zoo.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Learnable_Distributions_Zoo.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/probability/tensorflow_probability/examples/jupyter_notebooks/Learnable_Distributions_Zoo.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6m-yUlt3WP_n"
      },
      "source": [
        "In this colab we show various examples of building learnable (\"trainable\") distributions. (We make no effort to explain the distributions, only to show how to build them.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "LDScDeCWUuFf"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow_probability as tfp\n",
        "from tensorflow_probability.python.internal import prefer_static\n",
        "tfb = tfp.bijectors\n",
        "tfd = tfp.distributions\n",
        "tf.enable_v2_behavior()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bNpridm1U44X"
      },
      "outputs": [],
      "source": [
        "event_size = 4\n",
        "num_components = 3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vn_fp9oBV7NQ"
      },
      "source": [
        "## Learnable Multivariate Normal with Scaled Identity for `chol(Cov)`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 137
        },
        "colab_type": "code",
        "id": "ZED8qAvZU8yW",
        "outputId": "5fe791aa-0f05-42f1-a108-ccaed078585b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.Independent(\"learnable_mvn_scaled_identity\", batch_shape=[4], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(4, 1) dtype=float32, numpy=\n",
            "array([[0.],\n",
            "       [0.],\n",
            "       [0.],\n",
            "       [0.]], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "learnable_mvn_scaled_identity = tfd.Independent(\n",
        "    tfd.Normal(\n",
        "        loc=tf.Variable(tf.zeros(event_size), name='loc'),\n",
        "        scale=tfp.util.TransformedVariable(\n",
        "            tf.ones([event_size, 1]),\n",
        "            bijector=tfb.Exp()),\n",
        "            name='scale'),\n",
        "    reinterpreted_batch_ndims=1,\n",
        "    name='learnable_mvn_scaled_identity')\n",
        "\n",
        "print(learnable_mvn_scaled_identity)\n",
        "print(learnable_mvn_scaled_identity.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8LX4NbZ5V_UT"
      },
      "source": [
        "## Learnable Multivariate Normal with Diagonal for `chol(Cov)`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 70
        },
        "colab_type": "code",
        "id": "tH7NUglGU9Xs",
        "outputId": "391b8c1f-47a8-472d-fe3e-107c3244877c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.Independent(\"learnable_mvn_diag\", batch_shape=[], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "learnable_mvndiag = tfd.Independent(\n",
        "    tfd.Normal(\n",
        "        loc=tf.Variable(tf.zeros(event_size), name='loc'),\n",
        "        scale=tfp.util.TransformedVariable(\n",
        "            tf.ones(event_size),\n",
        "            bijector=tfb.Softplus()),  # Use Softplus...cuz why not?\n",
        "            name='scale'),\n",
        "    reinterpreted_batch_ndims=1,\n",
        "    name='learnable_mvn_diag')\n",
        "\n",
        "print(learnable_mvndiag)\n",
        "print(learnable_mvndiag.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "q7QlKm8fWCCH"
      },
      "source": [
        "## Mixture of Multivarite Normal (spherical)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 188
        },
        "colab_type": "code",
        "id": "4kaN1OP8U_S-",
        "outputId": "76921750-e8d2-4465-cc08-9d37321d23e1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvn_scaled_identity\", batch_shape=[], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=\n",
            "array([[-0.11800266, -1.3127382 , -0.1813932 ,  0.24975719],\n",
            "       [ 0.97209346, -0.91694   ,  0.84734786, -1.3201759 ],\n",
            "       [ 0.15194772,  0.33787215, -0.4269769 , -0.26286215]],\n",
            "      dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(3, 1) dtype=float32, numpy=\n",
            "array([[-4.600166],\n",
            "       [-4.600166],\n",
            "       [-4.600166]], dtype=float32)\u003e, \u003ctf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "learnable_mix_mvn_scaled_identity = tfd.MixtureSameFamily(\n",
        "    mixture_distribution=tfd.Categorical(\n",
        "        logits=tf.Variable(\n",
        "            # Changing the `1.` intializes with a geometric decay.\n",
        "            -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n",
        "            name='logits')),\n",
        "    components_distribution=tfd.Independent(\n",
        "        tfd.Normal(\n",
        "            loc=tf.Variable(\n",
        "              tf.random.normal([num_components, event_size]),\n",
        "              name='loc'),\n",
        "            scale=tfp.util.TransformedVariable(\n",
        "              10. * tf.ones([num_components, 1]),\n",
        "              bijector=tfb.Softplus()),  # Use Softplus...cuz why not?\n",
        "              name='scale'),\n",
        "        reinterpreted_batch_ndims=1),\n",
        "    name='learnable_mix_mvn_scaled_identity')\n",
        "\n",
        "print(learnable_mix_mvn_scaled_identity)\n",
        "print(learnable_mix_mvn_scaled_identity.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0h_8Sp24WD7R"
      },
      "source": [
        "## Mixture of Multivariate Normal (spherical) with first mix weight unlearnable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 171
        },
        "colab_type": "code",
        "id": "1oGkxQoLVOiu",
        "outputId": "579328df-2d2c-49c4-ec76-a29e049a779d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvndiag_first_fixed\", batch_shape=[], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=\n",
            "array([[-1., -1.,  1., -1.],\n",
            "       [-1., -1., -1., -1.],\n",
            "       [ 1.,  1.,  1., -1.]], dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=\n",
            "array([[-4.600166, -4.600166, -4.600166, -4.600166],\n",
            "       [-4.600166, -4.600166, -4.600166, -4.600166],\n",
            "       [-4.600166, -4.600166, -4.600166, -4.600166]], dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([-0.4054651, -0.8109302], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "learnable_mix_mvndiag_first_fixed = tfd.MixtureSameFamily(\n",
        "    mixture_distribution=tfd.Categorical(\n",
        "        logits=tfp.util.TransformedVariable(\n",
        "            # Initialize logits as geometric decay.\n",
        "            -tf.math.log(1.5) * tf.range(num_components, dtype=tf.float32),\n",
        "            tfb.Pad(paddings=[[1, 0]], constant_values=0)),\n",
        "            name='logits'),\n",
        "    components_distribution=tfd.Independent(\n",
        "        tfd.Normal(\n",
        "            loc=tf.Variable(\n",
        "                # Use Rademacher...cuz why not?\n",
        "                tfp.math.random_rademacher([num_components, event_size]),\n",
        "                name='loc'),\n",
        "            scale=tfp.util.TransformedVariable(\n",
        "                10. * tf.ones([num_components, 1]),\n",
        "                bijector=tfb.Softplus()),  # Use Softplus...cuz why not?\n",
        "                name='scale'),\n",
        "        reinterpreted_batch_ndims=1),\n",
        "    name='learnable_mix_mvndiag_first_fixed')\n",
        "\n",
        "print(learnable_mix_mvndiag_first_fixed)\n",
        "print(learnable_mix_mvndiag_first_fixed.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sOfgzpekWGcT"
      },
      "source": [
        "## Mixture of Multivariate Normal (full `Cov`)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 235
        },
        "colab_type": "code",
        "id": "Gz7kcj0zVQPv",
        "outputId": "94b05666-c72f-4ebc-e145-0f0f37faf040"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvntril\", batch_shape=[], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=\n",
            "array([[ 1.4036556 , -0.22486973,  0.8365339 ,  1.1744921 ],\n",
            "       [-0.14385273,  1.5095806 ,  0.78327304, -0.64133334],\n",
            "       [-0.22640549,  1.908316  ,  1.1216396 , -1.2109828 ]],\n",
            "      dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(3, 10) dtype=float32, numpy=\n",
            "array([[-4.6011715,  0.       ,  0.       ,  0.       , -4.6011715,\n",
            "        -4.6011715,  0.       ,  0.       ,  0.       , -4.6011715],\n",
            "       [-4.6011715,  0.       ,  0.       ,  0.       , -4.6011715,\n",
            "        -4.6011715,  0.       ,  0.       ,  0.       , -4.6011715],\n",
            "       [-4.6011715,  0.       ,  0.       ,  0.       , -4.6011715,\n",
            "        -4.6011715,  0.       ,  0.       ,  0.       , -4.6011715]],\n",
            "      dtype=float32)\u003e, \u003ctf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "learnable_mix_mvntril = tfd.MixtureSameFamily(\n",
        "    mixture_distribution=tfd.Categorical(\n",
        "        logits=tf.Variable(\n",
        "            # Changing the `1.` intializes with a geometric decay.\n",
        "            -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n",
        "            name='logits')),\n",
        "    components_distribution=tfd.MultivariateNormalTriL(\n",
        "        loc=tf.Variable(tf.zeros([num_components, event_size]), name='loc'),\n",
        "        scale_tril=tfp.util.TransformedVariable(\n",
        "            10. * tf.eye(event_size, batch_shape=[num_components]),\n",
        "            bijector=tfb.FillScaleTriL()),\n",
        "            name='scale_tril'),\n",
        "    name='learnable_mix_mvntril')\n",
        "\n",
        "print(learnable_mix_mvntril)\n",
        "print(learnable_mix_mvntril.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zU12fGeuWI29"
      },
      "source": [
        "## Mixture of Multivariate Normal (full `Cov`) with unlearnable first mix & first component"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 184
        },
        "colab_type": "code",
        "id": "NMb5hcsDVVyL",
        "outputId": "789d4d2d-23f5-406b-e3a2-cbc7ee578374"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvntril_fixed_first\", batch_shape=[], event_shape=[4], dtype=float32)\n",
            "(\u003ctf.Variable 'loc:0' shape=(2, 4) dtype=float32, numpy=\n",
            "array([[ 0.53900903, -0.17989647, -1.196744  , -1.0601326 ],\n",
            "       [ 0.46199334,  1.2968503 ,  0.20908853, -0.36455044]],\n",
            "      dtype=float32)\u003e, \u003ctf.Variable 'Variable:0' shape=(2, 10) dtype=float32, numpy=\n",
            "array([[-4.6011715,  0.       ,  0.       ,  0.       , -4.6011715,\n",
            "        -4.6011715,  0.       ,  0.       ,  0.       , -4.6011715],\n",
            "       [-4.6011715,  0.       ,  0.       ,  0.       , -4.6011715,\n",
            "        -4.6011715,  0.       ,  0.       ,  0.       , -4.6011715]],\n",
            "      dtype=float32)\u003e, \u003ctf.Variable 'logits:0' shape=(2,) dtype=float32, numpy=array([-0., -0.], dtype=float32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "# Make a bijector which pads an eye to what otherwise fills a tril.\n",
        "num_tril_nonzero = lambda num_rows: num_rows * (num_rows + 1) // 2\n",
        "\n",
        "num_tril_rows = lambda nnz: prefer_static.cast(\n",
        "    prefer_static.sqrt(0.25 + 2. * prefer_static.cast(nnz, tf.float32)) - 0.5,\n",
        "    tf.int32)\n",
        "\n",
        "# TFP doesn't have a concat bijector, so we roll out our own.\n",
        "class PadEye(tfb.Bijector):\n",
        "\n",
        "  def __init__(self, tril_fn=None):\n",
        "    if tril_fn is None:\n",
        "      tril_fn = tfb.FillScaleTriL()\n",
        "    self._tril_fn = getattr(tril_fn, 'inverse', tril_fn)\n",
        "    super(PadEye, self).__init__(\n",
        "        forward_min_event_ndims=2,\n",
        "        inverse_min_event_ndims=2,\n",
        "        is_constant_jacobian=True,\n",
        "        name='PadEye')\n",
        "\n",
        "  def _forward(self, x):\n",
        "    num_rows = int(num_tril_rows(tf.compat.dimension_value(x.shape[-1])))\n",
        "    eye = tf.eye(num_rows, batch_shape=prefer_static.shape(x)[:-2])\n",
        "    return tf.concat([self._tril_fn(eye)[..., tf.newaxis, :], x],\n",
        "                     axis=prefer_static.rank(x) - 2)\n",
        "\n",
        "  def _inverse(self, y):\n",
        "    return y[..., 1:, :]\n",
        "\n",
        "  def _forward_log_det_jacobian(self, x):\n",
        "    return tf.zeros([], dtype=x.dtype)\n",
        "\n",
        "  def _inverse_log_det_jacobian(self, y):\n",
        "    return tf.zeros([], dtype=y.dtype)\n",
        "\n",
        "  def _forward_event_shape(self, in_shape):\n",
        "    n = prefer_static.size(in_shape)\n",
        "    return in_shape + prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)\n",
        "\n",
        "  def _inverse_event_shape(self, out_shape):\n",
        "    n = prefer_static.size(out_shape)\n",
        "    return out_shape - prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)\n",
        "\n",
        "\n",
        "tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus())\n",
        "learnable_mix_mvntril_fixed_first = tfd.MixtureSameFamily(\n",
        "  mixture_distribution=tfd.Categorical(\n",
        "      logits=tfp.util.TransformedVariable(\n",
        "          # Changing the `1.` intializes with a geometric decay.\n",
        "          -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n",
        "          bijector=tfb.Pad(paddings=[(1, 0)]),\n",
        "          name='logits')),\n",
        "  components_distribution=tfd.MultivariateNormalTriL(\n",
        "      loc=tfp.util.TransformedVariable(\n",
        "          tf.zeros([num_components, event_size]),\n",
        "          bijector=tfb.Pad(paddings=[(1, 0)], axis=-2),\n",
        "          name='loc'),\n",
        "      scale_tril=tfp.util.TransformedVariable(\n",
        "          10. * tf.eye(event_size, batch_shape=[num_components]),\n",
        "          bijector=tfb.Chain([tril_bijector, PadEye(tril_bijector)]),\n",
        "          name='scale_tril')),\n",
        "  name='learnable_mix_mvntril_fixed_first')\n",
        "\n",
        "\n",
        "print(learnable_mix_mvntril_fixed_first)\n",
        "print(learnable_mix_mvntril_fixed_first.trainable_variables)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Learnable Distributions Zoo",
      "private_outputs": false,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
